import onnx
import copy
 
# ONNX: insert a new node into a graph in front of an existing node.
def insert_node(model, insert_node, follow_up_node):
    """Insert ``insert_node`` into ``model.graph.node`` directly before ``follow_up_node``.

    The first input of ``follow_up_node`` is rewired to the first output of
    ``insert_node``. The new node therefore sits between ``follow_up_node``
    and whatever it previously consumed as input 0.

    Args:
        model: model whose ``graph.node`` list is modified in place.
        insert_node: node to insert. Its ``output[0]`` becomes the new input 0
            of ``follow_up_node``. (This parameter shadows the function's own
            name inside the body; the name is kept for caller compatibility.)
        follow_up_node: node already in ``model.graph.node``. It is located by
            value equality (``==``), so a value-equal copy is also accepted.

    Raises:
        ValueError: if ``follow_up_node`` is None or no node in the graph
            equals it. In that case nothing is modified.
    """
    if follow_up_node is None:
        raise ValueError('follow_up_node is None; nothing to insert before')

    graph_nodes = model.graph.node
    # Locate the follow-up node *before* mutating anything, so that a miss
    # leaves both the model and the caller's node untouched.
    for follow_up_node_index, candidate in enumerate(graph_nodes):
        if candidate == follow_up_node:
            break
    else:
        raise ValueError('follow_up_node not found in model.graph.node')
    print("follow_up_node_index: ", follow_up_node_index)

    new_input = insert_node.output[0]
    # Rewire the graph's own element (it may be a different object from
    # follow_up_node) and the caller's object, keeping the original side effect.
    graph_nodes[follow_up_node_index].input[0] = new_input
    follow_up_node.input[0] = new_input
    # Place the new node just before its consumer in the node list.
    graph_nodes.insert(follow_up_node_index, insert_node)
 
# Model-specific rewrite plan, one entry per existing Transpose to patch:
#   (existing Transpose to patch, name for the new Transpose,
#    tensor the new Transpose reads, tensor it writes,
#    Conv the new Transpose is inserted in front of).
# NOTE(review): '79' / '101' are hard-coded tensor names from
# my_lprnet_model.onnx; presumably they are the tensors Conv_17 / Conv_39
# originally read as input 0. Confirm against the model.
_TRANSPOSE_PLAN = (
    ('Transpose_16', 'Transpose_17', '79', 'transpose_17_output', 'Conv_17'),
    ('Transpose_38', 'Transpose_39', '101', 'transpose_39_output', 'Conv_39'),
)


def _find_node(nodes, op_type, name):
    """Return the first node in *nodes* with matching ``op_type`` and ``name``.

    Raises:
        ValueError: if no node matches. This replaces a later, opaque
            AttributeError on None.
    """
    for candidate in nodes:
        if candidate.op_type == op_type and candidate.name == name:
            return candidate
    raise ValueError('%s node %r not found in graph' % (op_type, name))


def _patch_transpose_perm(transpose_node):
    """Set ``perm[1]``, ``perm[2]``, ``perm[3]`` of *transpose_node* to 2, 3, 1 in place.

    ``perm[0]`` is left unchanged. The ``perm`` attribute is looked up by name
    instead of being assumed to be ``attribute[0]``.

    Raises:
        ValueError: if the node has no ``perm`` attribute with at least 4 entries.
    """
    perm = None
    for attribute in transpose_node.attribute:
        if attribute.name == 'perm':
            perm = attribute.ints
            break
    if perm is None or len(perm) < 4:
        raise ValueError(
            '%r has no perm attribute with at least 4 entries' % transpose_node.name)
    perm[1] = 2
    perm[2] = 3
    perm[3] = 1


if __name__ == '__main__':
    src_onnx_model_path = './models/onnx/my_lprnet_model.onnx'
    dst_onnx_model_path = './models/onnx/new.onnx'
    onnx_model = onnx.load(src_onnx_model_path)
    node = onnx_model.graph.node

    # Print every Transpose with its index so the names in _TRANSPOSE_PLAN
    # can be checked against the loaded model.
    for i, graph_node in enumerate(node):
        if graph_node.op_type == 'Transpose':
            print(i, graph_node.name)

    # Pass 1: patch each existing Transpose's perm, then deep-copy the patched
    # node as the template for the new Transpose (it inherits the new perm).
    templates = []
    for src_name, _, _, _, _ in _TRANSPOSE_PLAN:
        src_node = _find_node(node, 'Transpose', src_name)
        _patch_transpose_perm(src_node)
        templates.append(copy.deepcopy(src_node))

    # Pass 2: rename/rewire each template and insert it in front of its Conv.
    for new_node, (_, new_name, new_input, new_output, conv_name) in zip(
            templates, _TRANSPOSE_PLAN):
        new_node.name = new_name
        new_node.input[0] = new_input
        new_node.output[0] = new_output
        insert_node(onnx_model, new_node, _find_node(node, 'Conv', conv_name))

    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, dst_onnx_model_path)